{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kF6XR0AMAuRW"
   },
   "source": [
    "# Week 2 - Ungraded Lab: A journey through Data\n",
    "\n",
    "Welcome to the ungraded lab for week 2 of Machine Learning Engineering for Production. **The paradigm behind Deep Learning is now facing a shift from model-centric to data-centric.** In this lab you will see how data intricacies affect the outcome of your models. To show you how far it will take you to apply data changes without addressing the model, you will be using a single model throughout: a simple Convolutional Neural Network (CNN). While training this model the journey will take you to address common problems: class imbalance and overfitting. As you navigate these issues, the lab will walk you through useful diagnosis tools and methods to mitigate these common problems.\n",
    "\n",
    "-------\n",
    "-------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lao0CVv7c3Rd"
   },
   "source": [
    "### **IMPORTANT NOTES BEFORE STARTING THE LAB**\n",
    "\n",
    "Once opened in Colab, click on the \"Connect\" button on the upper right side corner of the screen to connect to a runtime to run this lab.\n",
    "\n",
    "\n",
    "**NOTE 1:**\n",
    "\n",
    "For this lab you get the option to either train the models yourself (this takes around 20 minutes with GPU enabled for each model) or to use pretrained versions which are already provided. There are a total of 3 CNNs that require training and although some parameters have been tuned to provide a faster training time (such as `steps_per_epoch` and `validation_steps` which have been heavily lowered) this may result in a long time spent running this lab rather than thinking about what you observe.\n",
    "\n",
    "To speed things up we have provided saved pre-trained versions of each model along with their respective training history. We recommend you use these pre-trained versions to save time. However we also consider that training a model is an important learning experience especially if you haven't done this before. **If you want to perform this training by yourself, the code for replicating the training is provided as well. In this case the GPU is absolutely necessary, so be sure that it is enabled.**\n",
    "\n",
    "To make sure your runtime is GPU you can go to Runtime -> Change runtime type -> Select GPU from the menu and then press SAVE\n",
    "\n",
    "- Note: Restarting the runtime may\n",
    "be required.\n",
    "\n",
    "- Colab will tell you if restarting is necessary -- you can do this from Runtime -> Restart Runtime option in the dropdown.\n",
    "\n",
    "**If you decide to use the pretrained versions make sure you are not using a GPU as it is not required and may prevent other users from getting access to one.** To check this, go to Runtime -> Change runtime type -> Select None from the menu and then press SAVE.\n",
    "\n",
    "**NOTE 2:**\n",
    "\n",
    "Colab **does not** guarantee access to a GPU. This depends on the availability of these resources. However **it is not very common to be denied GPU access**. If this happens to you, you can still run this lab without training the models yourself. If you really want to do the training but are denied a GPU, try switching the runtime to a GPU after a couple of hours.\n",
    "\n",
    "To know more about Colab's policies check out this [FAQ](https://research.google.com/colaboratory/faq.html).\n",
    "\n",
    "-----------\n",
    "-----------\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LttdbzB5XB0O"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import random\n",
    "import zipfile\n",
    "import tarfile\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# To ignore some warnings about Image metadata that Pillow prints out\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v4Gq9Xffccwt"
   },
   "source": [
    "Before you move on, download the two datasets used in the lab, as well as the pretrained models and histories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CkTzJYihXWu3"
   },
   "outputs": [],
   "source": [
    "# Download datasets\n",
    "\n",
    "# Cats and dogs\n",
    "!wget https://storage.googleapis.com/mlep-public/course_1/week2/kagglecatsanddogs_3367a.zip\n",
    "\n",
    "# Caltech birds\n",
    "!wget https://storage.googleapis.com/mlep-public/course_1/week2/CUB_200_2011.tar\n",
    "\n",
    "# Download pretrained models and training histories\n",
    "!wget -q -P /content/model-balanced/ https://storage.googleapis.com/mlep-public/course_1/week2/model-balanced/saved_model.pb\n",
    "!wget -q -P /content/model-balanced/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-balanced/variables/variables.data-00000-of-00001\n",
    "!wget -q -P /content/model-balanced/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-balanced/variables/variables.index\n",
    "!wget -q -P /content/history-balanced/ https://storage.googleapis.com/mlep-public/course_1/week2/history-balanced/history-balanced.csv\n",
    "\n",
    "!wget -q -P /content/model-imbalanced/ https://storage.googleapis.com/mlep-public/course_1/week2/model-imbalanced/saved_model.pb\n",
    "!wget -q -P /content/model-imbalanced/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-imbalanced/variables/variables.data-00000-of-00001\n",
    "!wget -q -P /content/model-imbalanced/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-imbalanced/variables/variables.index\n",
    "!wget -q -P /content/history-imbalanced/ https://storage.googleapis.com/mlep-public/course_1/week2/history-imbalanced/history-imbalanced.csv\n",
    "\n",
    "!wget -q -P /content/model-augmented/ https://storage.googleapis.com/mlep-public/course_1/week2/model-augmented/saved_model.pb\n",
    "!wget -q -P /content/model-augmented/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-augmented/variables/variables.data-00000-of-00001\n",
    "!wget -q -P /content/model-augmented/variables/ https://storage.googleapis.com/mlep-public/course_1/week2/model-augmented/variables/variables.index\n",
    "!wget -q -P /content/history-augmented/ https://storage.googleapis.com/mlep-public/course_1/week2/history-augmented/history-augmented.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "suKuIsOYdC9G"
   },
   "source": [
    "## A story of data\n",
    "\n",
    "To guide you through this lab we have prepared a narrative that simulates a real life scenario:\n",
    "\n",
    "Suppose you have been tasked to create a model that classifies images of cats, dogs and birds. For this you settle on a simple CNN architecture, since CNN's are known to perform well for image classification. You are probably familiar with two widely used datasets: `cats vs dogs`, and `caltech birds`. As a side note both datasets are available through `Tensforflow Datasets (TFDS)`. However, you decide NOT to use `TFDS` since the lab requires you to modify the data and combine the two datasets into one. \n",
    "\n",
    "## Combining the datasets\n",
    "\n",
    "The raw images in these datasets can be found within the following paths:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-ja5V3AbYCp8"
   },
   "outputs": [],
   "source": [
    "cats_and_dogs_zip = '/content/kagglecatsanddogs_3367a.zip'\n",
    "caltech_birds_tar = '/content/CUB_200_2011.tar'\n",
    "\n",
    "base_dir = '/tmp/data'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xRqfAVn6e8Lp"
   },
   "source": [
    "The next step is extracting the data into a directory of choice, `base_dir` in this case.\n",
    "\n",
    "Note that the `cats vs dogs` images are in `zip` file format while the `caltech birds` images come in a `tar` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aUl3_4nVXcsE"
   },
   "outputs": [],
   "source": [
    "with zipfile.ZipFile(cats_and_dogs_zip, 'r') as my_zip:\n",
    "  my_zip.extractall(base_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JQYh7tAyqOA7"
   },
   "outputs": [],
   "source": [
    "with tarfile.open(caltech_birds_tar, 'r') as my_tar:\n",
    "  my_tar.extractall(base_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "65E3t5Qlfwwn"
   },
   "source": [
    "For the cats and dogs images no further preprocessing is needed as all exemplars of a single class are located in one directory: `PetImages\\Cat` and `PetImages\\Dog` respectively. Let's check how many images are available for each category:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "husRshAjYim9"
   },
   "outputs": [],
   "source": [
    "base_dogs_dir = os.path.join(base_dir, 'PetImages/Dog')\n",
    "base_cats_dir = os.path.join(base_dir,'PetImages/Cat')\n",
    "\n",
    "print(f\"There are {len(os.listdir(base_dogs_dir))} images of dogs\")\n",
    "print(f\"There are {len(os.listdir(base_cats_dir))} images of cats\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oqiG9G7-g2Z1"
   },
   "source": [
    "The Bird images dataset organization is quite different. This dataset is commonly used to classify species of birds so there is a directory for each species. Let's treat all species of birds as a single class. This requires moving all bird images to a single directory (`PetImages/Bird` will be used for consistency). This can be done by running the next cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ifcKshS6xmVj"
   },
   "outputs": [],
   "source": [
    "raw_birds_dir = '/tmp/data/CUB_200_2011/images'\n",
    "base_birds_dir = os.path.join(base_dir,'PetImages/Bird')\n",
    "os.mkdir(base_birds_dir)\n",
    "\n",
    "for subdir in os.listdir(raw_birds_dir):\n",
    "  subdir_path = os.path.join(raw_birds_dir, subdir)\n",
    "  for image in os.listdir(subdir_path):\n",
    "    shutil.move(os.path.join(subdir_path, image), os.path.join(base_birds_dir))\n",
    "\n",
    "print(f\"There are {len(os.listdir(base_birds_dir))} images of birds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9tteiK1fieHo"
   },
   "source": [
    "It turns out that there is a similar number of images for each class you are trying to predict! Nice!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z3jHPdb7SE61"
   },
   "source": [
    "Let's take a quick look at an image of each class you are trying to predict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lXE9RlF2ZFLL"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "print(\"Sample cat image:\")\n",
    "display(Image(filename=f\"{os.path.join(base_cats_dir, os.listdir(base_cats_dir)[0])}\"))\n",
    "print(\"\\nSample dog image:\")\n",
    "display(Image(filename=f\"{os.path.join(base_dogs_dir, os.listdir(base_dogs_dir)[0])}\"))\n",
    "print(\"\\nSample bird image:\")\n",
    "display(Image(filename=f\"{os.path.join(base_birds_dir, os.listdir(base_birds_dir)[0])}\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FduWstcripzJ"
   },
   "source": [
    "## Train / Evaluate Split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EiL9L8eSizCp"
   },
   "source": [
    "Before training the model you need to split the data into `training` and `evaluating` sets. For training, we have chosen the [`Keras`](https://keras.io) application programming interface (API) which includes functionality to read images from  various directories. The easier way to split the data is to create a different directory for each split of each class.\n",
    "\n",
    "Run the next cell to create the directories for training and evaluating sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NdBnzB2Mvcs2"
   },
   "outputs": [],
   "source": [
    "train_eval_dirs = ['train/cats', 'train/dogs', 'train/birds',\n",
    "                   'eval/cats', 'eval/dogs', 'eval/birds']\n",
    "\n",
    "for dir in train_eval_dirs:\n",
    "  if not os.path.exists(os.path.join(base_dir, dir)):\n",
    "    os.makedirs(os.path.join(base_dir, dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "x4XYN51Zj7-J"
   },
   "source": [
    "Now, let's define a function that will move a percentage of images from an origin folder to a destination folder as desired to generate the training and evaluation splits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DRpbU9HAdn4n"
   },
   "outputs": [],
   "source": [
    "def move_to_destination(origin, destination, percentage_split):\n",
    "  num_images = int(len(os.listdir(origin))*percentage_split)\n",
    "  for image_name, image_number in zip(sorted(os.listdir(origin)), range(num_images)):\n",
    "    shutil.move(os.path.join(origin, image_name), destination)"
   ]
  },
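  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the function above always takes the first files in sorted order. If file names happened to correlate with image content, a shuffled split might be preferable. The following is only a minimal sketch of that idea, assuming a hypothetical helper `split_names` and an arbitrary fixed seed; it is not part of the lab's pipeline:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def split_names(names, fraction, seed=42):\n",
    "    # Hypothetical helper: shuffle a sorted copy so the split is reproducible\n",
    "    shuffled = sorted(names)\n",
    "    random.Random(seed).shuffle(shuffled)\n",
    "    cutoff = int(len(shuffled) * fraction)\n",
    "    return shuffled[:cutoff], shuffled[cutoff:]\n",
    "\n",
    "# Made-up file names, used only to illustrate the split sizes\n",
    "names = [f'{i}.jpg' for i in range(10)]\n",
    "train_names, eval_names = split_names(names, 0.7)\n",
    "print(len(train_names), len(eval_names))\n",
    "```"
   ]
  },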
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DfssLKoathoG"
   },
   "source": [
    "And now you are ready to call the previous function and split the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VMKvQGH6fGdW"
   },
   "outputs": [],
   "source": [
    "# Move 70% of the images to the train dir\n",
    "move_to_destination(base_cats_dir, os.path.join(base_dir, 'train/cats'), 0.7)\n",
    "move_to_destination(base_dogs_dir, os.path.join(base_dir, 'train/dogs'), 0.7)\n",
    "move_to_destination(base_birds_dir, os.path.join(base_dir, 'train/birds'), 0.7)\n",
    "\n",
    "\n",
    "# Move the remaining images to the eval dir\n",
    "move_to_destination(base_cats_dir, os.path.join(base_dir, 'eval/cats'), 1)\n",
    "move_to_destination(base_dogs_dir, os.path.join(base_dir, 'eval/dogs'), 1)\n",
    "move_to_destination(base_birds_dir, os.path.join(base_dir, 'eval/birds'), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0eAD4J1ukGYC"
   },
   "source": [
    "Something important to mention is that as it currently stands your dataset has some issues that will prevent model training and evaluation. Mainly:\n",
    "\n",
    "1. Some images are corrupted and have zero bytes.\n",
    "2. Cats vs dogs zip file included a `.db` file for each class that needs to be deleted.\n",
    "\n",
    "If you didn't fix this before training you will get errors regarding these issues and training will fail. Zero-byte images are not valid images and Keras will let you know once these files are reached. In a similar way `.db` files are not valid images. **It is a good practice to always make sure that you are submitting files with the correct specifications to your training algorithm before start running it** as these issues might not be encountered right away and you will have to solve them and start training again.\n",
    "\n",
    "Running the following `bash` commands in the base directory will resolve these issues:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3An_dEi0hwHj"
   },
   "outputs": [],
   "source": [
    "!find /tmp/data/ -size 0 -exec rm {} +\n",
    "!find /tmp/data/ -type f ! -name \"*.jpg\" -exec rm {} +"
   ]
  },
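  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the same cleanup could be approximated in Python. This is only a sketch under stated assumptions: `remove_invalid_files` is a hypothetical helper, and the demo runs on a throwaway temporary directory with made-up file names rather than on `/tmp/data/`:\n",
    "\n",
    "```python\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "def remove_invalid_files(root):\n",
    "    # Delete zero-byte files and files whose name does not end in .jpg\n",
    "    removed = []\n",
    "    # Materialize the listing first so deletions do not affect the traversal\n",
    "    for path in list(Path(root).rglob('*')):\n",
    "        if not path.is_file():\n",
    "            continue\n",
    "        if path.stat().st_size == 0 or not path.name.endswith('.jpg'):\n",
    "            path.unlink()\n",
    "            removed.append(path)\n",
    "    return removed\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    (Path(tmp) / 'ok.jpg').write_bytes(b'data')\n",
    "    (Path(tmp) / 'empty.jpg').write_bytes(b'')\n",
    "    (Path(tmp) / 'Thumbs.db').write_bytes(b'data')\n",
    "    removed_names = sorted(p.name for p in remove_invalid_files(tmp))\n",
    "    remaining_names = sorted(p.name for p in Path(tmp).iterdir())\n",
    "\n",
    "print(removed_names, remaining_names)\n",
    "```"
   ]
  },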
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oeqbprKcmr-0"
   },
   "source": [
    "The first command removes all zero-byte files from the filesystem. The second one removes any file that does not have a `.jpg` extension. \n",
    "\n",
    "This also serves as a reminder of the power of bash. Although you could achieve the same result with Python code, bash allows you to do this much quicker. If you are not familiar with bash or some other shell-like language we encourage you to learn some of it as it is a very useful tool for data manipulation purposes.\n",
    "\n",
    "Let's check how many images you have available for each split and class after you remove the corrupted images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nZFk4f0jhEAk"
   },
   "outputs": [],
   "source": [
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'train/cats')))} images of cats for training\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'train/dogs')))} images of dogs for training\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'train/birds')))} images of birds for training\\n\")\n",
    "\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'eval/cats')))} images of cats for evaluation\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'eval/dogs')))} images of dogs for evaluation\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'eval/birds')))} images of birds for evaluation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LSmRaN_Qm-s4"
   },
   "source": [
    "It turns out that very few files presented the issues mentioned above. That's good news but it is also a reminder that small problems with the dataset might unexpectedly affect the training process. In this case, 4 non valid image files will have prevented you from training the model. \n",
    "\n",
    "In most cases training Deep Learning models is a time intensive task, so be sure to have everything in place before starting this process.\n",
    "\n",
    "\n",
    "## An unexpected issue!\n",
    "\n",
    "Let's face the first real life issue in this narrative! There was a power outage in your office and some hard drives were damaged and as a result of that, many of the images for `dogs` and `birds` have been erased. As a matter of fact, only 20% of the dog images and 10% of the bird images survived.\n",
    "\n",
    "To simulate this scenario, let's quickly create a new directory called `imbalanced` and copy only the proportions mentioned above for each class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wAG-rJRPZTQt"
   },
   "outputs": [],
   "source": [
    "for dir in train_eval_dirs:\n",
    "  if not os.path.exists(os.path.join(base_dir, 'imbalanced/'+dir)):\n",
    "    os.makedirs(os.path.join(base_dir, 'imbalanced/'+dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GAGTj51qZT4e"
   },
   "outputs": [],
   "source": [
    "# Very similar to the one used before but this one copies instead of moving\n",
    "def copy_with_limit(origin, destination, percentage_split):\n",
    "  num_images = int(len(os.listdir(origin))*percentage_split)\n",
    "  for image_name, image_number in zip(sorted(os.listdir(origin)), range(num_images)):\n",
    "    shutil.copy(os.path.join(origin, image_name), destination)\n",
    "\n",
    "# Perform the copying\n",
    "copy_with_limit(os.path.join(base_dir, 'train/cats'), os.path.join(base_dir, 'imbalanced/train/cats'), 1)\n",
    "copy_with_limit(os.path.join(base_dir, 'train/dogs'), os.path.join(base_dir, 'imbalanced/train/dogs'), 0.2)\n",
    "copy_with_limit(os.path.join(base_dir, 'train/birds'), os.path.join(base_dir, 'imbalanced/train/birds'), 0.1)\n",
    "\n",
    "copy_with_limit(os.path.join(base_dir, 'eval/cats'), os.path.join(base_dir, 'imbalanced/eval/cats'), 1)\n",
    "copy_with_limit(os.path.join(base_dir, 'eval/dogs'), os.path.join(base_dir, 'imbalanced/eval/dogs'), 0.2)\n",
    "copy_with_limit(os.path.join(base_dir, 'eval/birds'), os.path.join(base_dir, 'imbalanced/eval/birds'), 0.1)\n",
    "\n",
    "# Print number of available images\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/train/cats')))} images of cats for training\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/train/dogs')))} images of dogs for training\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/train/birds')))} images of birds for training\\n\")\n",
    "\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/eval/cats')))} images of cats for evaluation\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/eval/dogs')))} images of dogs for evaluation\")\n",
    "print(f\"There are {len(os.listdir(os.path.join(base_dir, 'imbalanced/eval/birds')))} images of birds for evaluation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2Qt_EGGJAaOR"
   },
   "source": [
    "For now there is no quick or clear solution to the accidental file loss. So you decide to keep going and train the model with the remaining images."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qlDuR43ZAfwk"
   },
   "source": [
    "## Selecting the model\n",
    "\n",
    "Let's go ahead and create a model architecture and define a loss function, optimizer and performance metrics leveraging keras API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AiTGrTiHZ9fS"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers, models, optimizers\n",
    "\n",
    "def create_model():\n",
    "  # A simple CNN architecture based on the one found here: https://www.tensorflow.org/tutorials/images/classification\n",
    "  model = models.Sequential([\n",
    "  layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),\n",
    "  layers.MaxPooling2D((2, 2)),\n",
    "  layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "  layers.MaxPooling2D((2, 2)),\n",
    "  layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "  layers.MaxPooling2D((2, 2)),\n",
    "  layers.Conv2D(128, (3, 3), activation='relu'),\n",
    "  layers.MaxPooling2D((2, 2)),\n",
    "  layers.Flatten(),\n",
    "  layers.Dense(512, activation='relu'),\n",
    "  layers.Dense(3, activation='softmax')\n",
    "  ])\n",
    "\n",
    "\n",
    "  # Compile the model\n",
    "  model.compile(\n",
    "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
    "      optimizer=optimizers.Adam(),\n",
    "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n",
    "  )\n",
    "\n",
    "  return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UVj-I-Ke03Au"
   },
   "source": [
    "And let's print out a model summary as a quick check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "elM3J9P8I_zu"
   },
   "outputs": [],
   "source": [
    "# Create a model to use with the imbalanced dataset\n",
    "imbalanced_model = create_model()\n",
    "\n",
    "# Print the model's summary\n",
    "print(imbalanced_model.summary())"
   ]
  },
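  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output shapes and parameter counts reported by the summary can be cross-checked by hand. A minimal sketch, assuming the Keras defaults used in `create_model` (`'valid'` padding and stride 1 for `Conv2D`, stride equal to the pool size for `MaxPooling2D`); the helper names are made up for illustration:\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel=3):\n",
    "    # 'valid' padding with stride 1 shrinks each side by kernel - 1\n",
    "    return size - kernel + 1\n",
    "\n",
    "def pool_out(size, pool=2):\n",
    "    # 2x2 max pooling with stride 2 halves the size, rounding down\n",
    "    return size // pool\n",
    "\n",
    "size, channels, total_params = 150, 3, 0\n",
    "for filters in (32, 64, 64, 128):\n",
    "    # Conv2D parameters: 3x3 kernel weights plus one bias per filter\n",
    "    total_params += 3 * 3 * channels * filters + filters\n",
    "    size = pool_out(conv_out(size))\n",
    "    channels = filters\n",
    "\n",
    "flat = size * size * channels\n",
    "total_params += flat * 512 + 512  # Dense(512): weights plus biases\n",
    "total_params += 512 * 3 + 3       # Dense(3): weights plus biases\n",
    "\n",
    "print(f'final feature map: {size}x{size}x{channels}')\n",
    "print(f'flattened units: {flat}, total parameters: {total_params}')\n",
    "```\n",
    "\n",
    "Most of the parameters sit in the first dense layer, which is one reason a model like this can overfit a small dataset."
   ]
  },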
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9YjjV9iU78Ca"
   },
   "source": [
    "For training the model you will be using Keras' ImageDataGenerator, which has built-in functionalities to easily feed your model with raw, rescaled or even augmented image data.\n",
    "\n",
    "Another cool functionality within ImageDataGenerator is the `flow_from_directory` method which allows to read images as needed from a root directory. This method needs the following arguments:\n",
    "\n",
    "- `directory`: Path to the root directory where the images are stored.\n",
    "- `target_size`: The dimensions to which all images found will be resized. Since images come in all kinds of resolutions, you need to standardize their size. 150x150 is used but other values should work well too.\n",
    "- `batch_size`: Number of images the generator yields everytime it is asked for a next batch. 32 is used here.\n",
    "- `class_mode`: How the labels are represented. Here \"binary\" is used to indicate that labels will be 1D. This is done for compatibility with the loss and evaluation metrics used when compiling the model.\n",
    "\n",
    "If you want to learn more about using Keras' ImageDataGenerator, check this [tutorial](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4SyU0P66azNE"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "# No data augmentation for now, only normalizing pixel values\n",
    "train_datagen = ImageDataGenerator(rescale=1./255)\n",
    "test_datagen = ImageDataGenerator(rescale=1./255)\n",
    "\n",
    "# Point to the imbalanced directory\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "        '/tmp/data/imbalanced/train',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')\n",
    "\n",
    "validation_generator = test_datagen.flow_from_directory(\n",
    "        '/tmp/data/imbalanced/eval',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NsowgcmDAOv-"
   },
   "source": [
    "Let's do a quick sanity check to inspect that both generators (training and validation) use the same labels for each class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MlCgRwvWX8BO"
   },
   "outputs": [],
   "source": [
    "print(f\"labels for each class in the train generator are: {train_generator.class_indices}\")\n",
    "print(f\"labels for each class in the validation generator are: {validation_generator.class_indices}\")"
   ]
  },
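  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`flow_from_directory` infers the classes from the subdirectory names and assigns indices in alphanumeric order. The sketch below only mimics that documented ordering with a hypothetical helper; it is not Keras' actual implementation:\n",
    "\n",
    "```python\n",
    "def class_indices_from_dirs(subdirs):\n",
    "    # Sort the class directory names, then number them from 0\n",
    "    return {name: index for index, name in enumerate(sorted(subdirs))}\n",
    "\n",
    "indices = class_indices_from_dirs(['cats', 'dogs', 'birds'])\n",
    "print(indices)\n",
    "```\n",
    "\n",
    "This ordering is why label 0 corresponds to birds, 1 to cats and 2 to dogs in the rest of the lab."
   ]
  },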
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UqXdzv-soUzj"
   },
   "source": [
    "\n",
    "# Training a CNN with class imbalanced data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O1DI3mKCraJQ"
   },
   "outputs": [],
   "source": [
    "# Load pretrained model and history\n",
    "\n",
    "imbalanced_history = pd.read_csv('history-imbalanced/history-imbalanced.csv')\n",
    "imbalanced_model = tf.keras.models.load_model('model-imbalanced')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UUhtEnsgxZ00"
   },
   "outputs": [],
   "source": [
    "# Run only if you want to train the model yourself (this takes around 20 mins with GPU enabled)\n",
    "\n",
    "# imbalanced_history = imbalanced_model.fit(\n",
    "#     train_generator,\n",
    "#     steps_per_epoch=100,\n",
    "#     epochs=50,\n",
    "#     validation_data=validation_generator,\n",
    "#     validation_steps=80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9kHwAYLvEhiQ"
   },
   "source": [
    "To analyze the model performance properly, it is important to track different metrics such as accuracy and loss function along the training process. Let's define a helper function to handle the metrics through the training history,depending on the method you previously selected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kmoJLjoTzb_L"
   },
   "outputs": [],
   "source": [
    "def get_training_metrics(history):\n",
    "  \n",
    "  # This is needed depending on if you used the pretrained model or you trained it yourself\n",
    "  if not isinstance(history, pd.core.frame.DataFrame):\n",
    "    history = history.history\n",
    "  \n",
    "  acc = history['sparse_categorical_accuracy']\n",
    "  val_acc = history['val_sparse_categorical_accuracy']\n",
    "\n",
    "  loss = history['loss']\n",
    "  val_loss = history['val_loss']\n",
    "\n",
    "  return acc, val_acc, loss, val_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8OKRhD87E-V3"
   },
   "source": [
    "Now, let's plot the metrics and losses for each training epoch as the training process progresses. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RcYuJgrr11h4"
   },
   "outputs": [],
   "source": [
    "def plot_train_eval(history):\n",
    "  acc, val_acc, loss, val_loss = get_training_metrics(history)\n",
    "\n",
    "  acc_plot = pd.DataFrame({\"training accuracy\":acc, \"evaluation accuracy\":val_acc})\n",
    "  acc_plot = sns.lineplot(data=acc_plot)\n",
    "  acc_plot.set_title('training vs evaluation accuracy')\n",
    "  acc_plot.set_xlabel('epoch')\n",
    "  acc_plot.set_ylabel('sparse_categorical_accuracy')\n",
    "  plt.show()\n",
    "\n",
    "  print(\"\")\n",
    "\n",
    "  loss_plot = pd.DataFrame({\"training loss\":loss, \"evaluation loss\":val_loss})\n",
    "  loss_plot = sns.lineplot(data=loss_plot)\n",
    "  loss_plot.set_title('training vs evaluation loss')\n",
    "  loss_plot.set_xlabel('epoch')\n",
    "  loss_plot.set_ylabel('loss')\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "\n",
    "plot_train_eval(imbalanced_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4mF4fltDFM6o"
   },
   "source": [
    "From these two plots is quite evident that the model is overfitting the training data. However, the evaluation accuracy is still pretty high. Maybe class imbalance is not such a big issue after all. Perhaps this is too good to be true. \n",
    "\n",
    "Let's dive a little deeper, and compute some additional metrics to explore if the class imbalance is hampering the model to perform well. In particular, let's compare: the accuracy score,  the accuracy score balanced, and the confusion matrix.  Information on the accuracy scores calculations is provided in the [sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics) documentation. To refresh ideas on what is a confusion matrix check [Wikipedia](https://en.wikipedia.org/wiki/Confusion_matrix)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kB_8ipYTK6FF"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score, balanced_accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QQRokFyn-KIN"
   },
   "outputs": [],
   "source": [
    "# Use the validation generator without shuffle to easily compute additional metrics\n",
    "val_gen_no_shuffle = test_datagen.flow_from_directory(\n",
    "    '/tmp/data/imbalanced/eval',\n",
    "    target_size=(150, 150),\n",
    "    batch_size=32,\n",
    "    class_mode='binary',\n",
    "    shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yJEg83EIW_jm"
   },
   "outputs": [],
   "source": [
    "# Get the true labels from the generator\n",
    "y_true = val_gen_no_shuffle.classes\n",
    "\n",
    "# Use the model to predict (will take a couple of minutes)\n",
    "predictions_imbalanced = imbalanced_model.predict(val_gen_no_shuffle)\n",
    "\n",
    "# Get the argmax (since softmax is being used)\n",
    "y_pred_imbalanced = np.argmax(predictions_imbalanced, axis=1)\n",
    "\n",
    "# Print accuracy score\n",
    "print(f\"Accuracy Score: {accuracy_score(y_true, y_pred_imbalanced)}\")\n",
    "\n",
    "# Print balanced accuracy score\n",
    "print(f\"Balanced Accuracy Score: {balanced_accuracy_score(y_true, y_pred_imbalanced)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cXQQR9D8HVUh"
   },
   "source": [
    "Comparing the `accuracy` and `balanced accuracy` metrics, the class imbalance starts to become apparent. Now let's compute the `confusion matrix` of the predictions. Notice that the class imbalance is also present in the evaluation set so the confusion matrix will show an overwhelming majority for cats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zZqpe9uLN2k0"
   },
   "outputs": [],
   "source": [
    "imbalanced_cm = confusion_matrix(y_true, y_pred_imbalanced)\n",
    "ConfusionMatrixDisplay(imbalanced_cm, display_labels=['birds', 'cats', 'dogs']).plot(values_format=\"d\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nu3xXDhYAnqL"
   },
   "outputs": [],
   "source": [
    "misclassified_birds = (imbalanced_cm[1,0] + imbalanced_cm[2,0])/np.sum(imbalanced_cm, axis=0)[0]\n",
    "misclassified_cats = (imbalanced_cm[0,1] + imbalanced_cm[2,1])/np.sum(imbalanced_cm, axis=0)[1]\n",
    "misclassified_dogs = (imbalanced_cm[0,2] + imbalanced_cm[1,2])/np.sum(imbalanced_cm, axis=0)[2]\n",
    "\n",
    "print(f\"Proportion of misclassified birds: {misclassified_birds*100:.2f}%\")\n",
    "print(f\"Proportion of misclassified cats: {misclassified_cats*100:.2f}%\")\n",
    "print(f\"Proportion of misclassified dogs: {misclassified_dogs*100:.2f}%\")"
   ]
  },
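  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading a confusion matrix row-wise or column-wise answers two different questions. Under scikit-learn's convention, entry (i, j) counts images whose true class is i and whose predicted class is j. The sketch below uses a made-up matrix and two hypothetical helper functions to show both readings. It is only an illustration in plain Python, not the lab's own code.\n",
    "\n",
    "```python\n",
    "# Toy confusion matrix (made-up counts); rows = true class, columns = predicted class\n",
    "labels = ['birds', 'cats', 'dogs']\n",
    "cm = [\n",
    "    [10, 35, 5],   # true birds\n",
    "    [2, 290, 8],   # true cats\n",
    "    [4, 30, 16],   # true dogs\n",
    "]\n",
    "\n",
    "\n",
    "def misclassified_share(cm, i):\n",
    "    # Share of the true class i images that were predicted as some other class\n",
    "    row_total = sum(cm[i])\n",
    "    return (row_total - cm[i][i]) / row_total\n",
    "\n",
    "\n",
    "def wrong_prediction_share(cm, j):\n",
    "    # Share of the class j predictions that actually belong to another class\n",
    "    col_total = sum(row[j] for row in cm)\n",
    "    return (col_total - cm[j][j]) / col_total\n",
    "\n",
    "\n",
    "for i, name in enumerate(labels):\n",
    "    print(f'{name}: {misclassified_share(cm, i) * 100:.2f}% of true images misclassified, '\n",
    "          f'{wrong_prediction_share(cm, i) * 100:.2f}% of predictions wrong')\n",
    "```\n",
    "\n",
    "In this toy matrix 80% of the true birds are misclassified, while 37.5% of the images predicted as birds are wrong. The row-wise reading is the one that exposes how badly a minority class is served."
   ]
  },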
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e3tpDKCsT564"
   },
   "source": [
    "Class imbalance is a real problem that if not detected early on, gives the wrong impression that your model is performing better than it actually is. For this reason,  is important to rely on several metrics that do a better job at capturing these kinds of issues. **In this case the standard `accuracy` metric is misleading** and provides a false sense that the model is performing better than it actually is.\n",
    "\n",
    "To prove this point further consider a model that only predicts cats:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Yv65fC5NK5sV"
   },
   "outputs": [],
   "source": [
    "# Predict cat for all images\n",
    "all_cats = np.ones(y_true.shape)\n",
    "\n",
    "# Print accuracy score\n",
    "print(f\"Accuracy Score: {accuracy_score(y_true, all_cats)}\")\n",
    "\n",
    "# Print balanced accuracy score\n",
    "print(f\"Balanced Accuracy Score: {balanced_accuracy_score(y_true, all_cats)}\")"
   ]
  },
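  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation shows why a constant prediction scores so differently on the two metrics. The symbols are introduced here only for illustration: $N$ is the number of evaluation images, $n_c$ the number of images of class $c$, and $K$ the number of classes. A model that always predicts class $k$ gets every image of class $k$ right and every other image wrong, so\n",
    "\n",
    "$$\\text{accuracy} = \\frac{n_k}{N}, \\qquad \\text{balanced accuracy} = \\frac{1}{K}\\sum_{c=1}^{K} \\text{recall}_c = \\frac{1}{K}(1 + 0 + \\dots + 0) = \\frac{1}{K}.$$\n",
    "\n",
    "With $K = 3$ the balanced accuracy of such a model is always $1/3$, no matter how dominant the predicted class is, whereas its accuracy grows with the share of that class in the evaluation set."
   ]
  },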
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g_Gp6mYcIQlW"
   },
   "source": [
    "If you only look at the `accuracy` metric the model seems to be working fairly well, since the majority class is the same that the model always predicts.\n",
    "\n",
    "There are several techniques to deal with class imbalance. A very popular one is `SMOTE`, which oversamples the minority classes by creating syntethic data. However, these techniques are outside the scope of this lab.\n",
    "\n",
    "The previous metrics were computed with class imbalance both on the training and evaluation sets. If you are wondering how the model performed with class imbalance only on the training set run the following cell to see the confusion matrix with balanced classes in the evaluation set:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r6xecVSuqMLx"
   },
   "outputs": [],
   "source": [
    "# Use the validation generator without shuffle to easily compute additional metrics\n",
    "val_gen_no_shuffle = test_datagen.flow_from_directory(\n",
    "    '/tmp/data/eval',\n",
    "    target_size=(150, 150),\n",
    "    batch_size=32,\n",
    "    class_mode='binary',\n",
    "    shuffle=False)\n",
    "\n",
    "# Get the true labels from the generator\n",
    "y_true = val_gen_no_shuffle.classes\n",
    "\n",
    "# Use the model to predict (will take a couple of minutes)\n",
    "predictions_imbalanced = imbalanced_model.predict(val_gen_no_shuffle)\n",
    "\n",
    "# Get the argmax (since softmax is being used)\n",
    "y_pred_imbalanced = np.argmax(predictions_imbalanced, axis=1)\n",
    "\n",
    "# Confusion matrix\n",
    "imbalanced_cm = confusion_matrix(y_true, y_pred_imbalanced)\n",
    "ConfusionMatrixDisplay(imbalanced_cm, display_labels=['birds', 'cats', 'dogs']).plot(values_format=\"d\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R5vJRVjlQvK-"
   },
   "source": [
    "# Training with the complete dataset\n",
    "\n",
    "For the time being and following the narrative, assume that a colleague of yours was careful enough to save a backup of the complete dataset in her cloud storage. Now you can try training without the class imbalance issue, what a relief!\n",
    "\n",
    "Now that you have the complete dataset it is time to try again without suffering from class imbalance. **In general, collecting more data is beneficial for models!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w5VwUrpGPhH_"
   },
   "outputs": [],
   "source": [
    "# Create a model to use with the balanced dataset\n",
    "balanced_model = create_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FWFrVUmsmzzs"
   },
   "outputs": [],
   "source": [
    "# Still no data augmentation, only re-scaling\n",
    "train_datagen = ImageDataGenerator(rescale=1./255)\n",
    "test_datagen = ImageDataGenerator(rescale=1./255)\n",
    "\n",
    "# Generators now point to the complete dataset\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "        '/tmp/data/train',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')\n",
    "\n",
    "validation_generator = test_datagen.flow_from_directory(\n",
    "        '/tmp/data/eval',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WC7-I1ylr-_n"
   },
   "outputs": [],
   "source": [
    "# Load pretrained model and history\n",
    "\n",
    "balanced_history = pd.read_csv('history-balanced/history-balanced.csv')\n",
    "balanced_model = tf.keras.models.load_model('model-balanced')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NcOh1NVtm5Dg"
   },
   "outputs": [],
   "source": [
    "# Run only if you want to train the model yourself (this takes around 20 mins with GPU enabled)\n",
    "\n",
    "# balanced_history = balanced_model.fit(\n",
    "#     train_generator,\n",
    "#     steps_per_epoch=100,\n",
    "#     epochs=50,\n",
    "#     validation_data=validation_generator,\n",
    "#     validation_steps=80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i7LZUa9RVvyX"
   },
   "source": [
    "Let's check how the `accuracy` vs `balanced accuracy` comparison looks like now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EE3YiUW5WMOQ"
   },
   "outputs": [],
   "source": [
    "# Use the validation generator without shuffle to easily compute additional metrics\n",
    "val_gen_no_shuffle = test_datagen.flow_from_directory(\n",
    "    '/tmp/data/eval',\n",
    "    target_size=(150, 150),\n",
    "    batch_size=32,\n",
    "    class_mode='binary',\n",
    "    shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wfLgvWRfKuTQ"
   },
   "outputs": [],
   "source": [
    "# Get the true labels from the generator\n",
    "y_true = val_gen_no_shuffle.classes\n",
    "\n",
    "# Use the model to predict (will take a couple of minutes)\n",
    "predictions_balanced = balanced_model.predict(val_gen_no_shuffle)\n",
    "\n",
    "# Get the argmax (since softmax is being used)\n",
    "y_pred_balanced = np.argmax(predictions_balanced, axis=1)\n",
    "\n",
    "# Print accuracy score\n",
    "print(f\"Accuracy Score: {accuracy_score(y_true, y_pred_balanced)}\")\n",
    "\n",
    "# Print balanced accuracy score\n",
    "print(f\"Balanced Accuracy Score: {balanced_accuracy_score(y_true, y_pred_balanced)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7Mpnmv5YKyeD"
   },
   "outputs": [],
   "source": [
    "balanced_cm = confusion_matrix(y_true, y_pred_balanced)\n",
    "ConfusionMatrixDisplay(balanced_cm, display_labels=['birds', 'cats', 'dogs']).plot(values_format=\"d\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Dp7QCgZ0Wuf3"
   },
   "source": [
    "Both accuracy-based metrics are very similar now. The confusion matrix also looks way better than before. This suggests that class imbalance has been successfully mitigated by adding more data to the previously undersampled classes.\n",
    "\n",
    "Now that you now that you can trust the `accuracy` metric, let's plot the training history:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6pr2VmKtJpet"
   },
   "outputs": [],
   "source": [
    "plot_train_eval(balanced_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YCH1hTj7JvHu"
   },
   "source": [
    "This looks much better than for the imbalanced case! However, overfitting is still present.\n",
    "\n",
    "Can you think of ways to address this issue? If you are familiar with CNN's you might think of adding `dropout` layers. This intuition is correct but for the time being you decide to stick with the same model and only change the data to see if it is possible to mitigate overfitting in this manner.\n",
    "\n",
    "Another possible solution is to apply data augmentation techniques. Your whole team agrees this is the way to go so you decide to try this next!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VdlVWEZuX4ii"
   },
   "source": [
    "# Training with Data Augmentation\n",
    "\n",
    "Augmenting images is a technique in which you create new versions of the images you have at hand, by applying geometric transformations. These transformations can vary from: zooming in and out, rotating, or even flipping the images. By doing this, you get a training dataset that exposes the model to a wider variety of images. This helps in further exploring the feature space and hence reducing the chances of overfitting. \n",
    "\n",
    "It is also a very natural idea since doing slight (or sometimes not so slight) changes to an image will result in an equally valid image. A cat sitting in an awkward position is still a cat, right?"
   ]
  },
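  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what a geometric transformation does to the pixel grid, the sketch below flips and shifts a tiny made-up image stored as nested lists. It is plain Python for readability and is not how `ImageDataGenerator` is implemented. The helper names are hypothetical, and the zero padding used for the shift is a simplification chosen for this sketch.\n",
    "\n",
    "```python\n",
    "# A tiny 3x4 grayscale image as nested lists (made-up pixel values)\n",
    "image = [\n",
    "    [0, 1, 2, 3],\n",
    "    [4, 5, 6, 7],\n",
    "    [8, 9, 10, 11],\n",
    "]\n",
    "\n",
    "\n",
    "def horizontal_flip(img):\n",
    "    # Mirror left to right: reverse the pixels in every row\n",
    "    return [row[::-1] for row in img]\n",
    "\n",
    "\n",
    "def vertical_flip(img):\n",
    "    # Mirror top to bottom: reverse the order of the rows\n",
    "    return img[::-1]\n",
    "\n",
    "\n",
    "def shift_right(img, pixels, fill=0):\n",
    "    # Move the content to the right and pad the exposed columns with a fill value\n",
    "    return [[fill] * pixels + row[:len(row) - pixels] for row in img]\n",
    "\n",
    "\n",
    "augmented = [horizontal_flip(image), vertical_flip(image), shift_right(image, 1)]\n",
    "for variant in augmented:\n",
    "    print(variant)\n",
    "```\n",
    "\n",
    "Each variant keeps the original shape and the same content, only rearranged, which is why the label of the image stays valid after the transformation."
   ]
  },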
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V1EUr1eTVXEz"
   },
   "outputs": [],
   "source": [
    "# Create a model to use with the balanced and augmented dataset\n",
    "augmented_model = create_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "g7RAqkSRC98K"
   },
   "outputs": [],
   "source": [
    "# Now applying image augmentation\n",
    "train_datagen = ImageDataGenerator(\n",
    "        rescale=1./255,\n",
    "        rotation_range=50,\n",
    "        width_shift_range=0.15,\n",
    "        height_shift_range=0.15,\n",
    "        shear_range=0.2,\n",
    "        zoom_range=0.2,\n",
    "        horizontal_flip=True)\n",
    "\n",
    "\n",
    "test_datagen = ImageDataGenerator(rescale=1./255)\n",
    "\n",
    "# Still pointing to directory with full dataset\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "        '/tmp/data/train',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')\n",
    "\n",
    "validation_generator = test_datagen.flow_from_directory(\n",
    "        '/tmp/data/eval',\n",
    "        target_size=(150, 150),\n",
    "        batch_size=32,\n",
    "        class_mode='binary')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DnmSteAYA4s3"
   },
   "source": [
    "Notice that the only difference with the previous training is that the `ImageDataGenerator` object now has some extra parameters. We encourage you to read more about this topic [here](https://keras.io/api/preprocessing/image/) if you haven't already. Also **this was only done to the training generator since this technique should only be applied to the training images.**\n",
    "\n",
    "\n",
    "But what exactly are these extra parameters doing? \n",
    "\n",
    "Let's see these transformations in action. The following cell applies and displays different transformations for a single image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iiu_u0iRqgFM"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, load_img\n",
    "\n",
    "\n",
    "# Displays transformations on random images of birds in the training partition\n",
    "def display_transformations(gen):\n",
    "  train_birds_dir = \"/tmp/data/train/birds\"\n",
    "  random_index = random.randint(0, len(os.listdir(train_birds_dir)))\n",
    "  sample_image = load_img(f\"{os.path.join(train_birds_dir, os.listdir(train_birds_dir)[random_index])}\", target_size=(150, 150))\n",
    "  sample_array = img_to_array(sample_image)\n",
    "  sample_array = sample_array[None, :]\n",
    "\n",
    "\n",
    "  for iteration, array in zip(range(4), gen.flow(sample_array, batch_size=1)):\n",
    "    array = np.squeeze(array)\n",
    "    img = array_to_img(array)\n",
    "    print(f\"\\nTransformation number: {iteration}\\n\")\n",
    "    display(img)\n",
    "\n",
    "\n",
    "# An example of an ImageDataGenerator\n",
    "sample_gen = ImageDataGenerator(\n",
    "        rescale=1./255,\n",
    "        rotation_range=50,\n",
    "        width_shift_range=0.25,\n",
    "        height_shift_range=0.25,\n",
    "        shear_range=0.2,\n",
    "        zoom_range=0.25,\n",
    "        horizontal_flip=True)\n",
    "\n",
    "display_transformations(sample_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OUNLR1NFBED3"
   },
   "source": [
    "Let's look at another more extreme example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "biDxKkdx09bg"
   },
   "outputs": [],
   "source": [
    "# An ImageDataGenerator with more extreme data augmentation\n",
    "sample_gen = ImageDataGenerator(\n",
    "        rescale=1./255,\n",
    "        rotation_range=90,\n",
    "        width_shift_range=0.3,\n",
    "        height_shift_range=0.3,\n",
    "        shear_range=0.5,\n",
    "        zoom_range=0.5,\n",
    "        vertical_flip=True,\n",
    "        horizontal_flip=True)\n",
    "\n",
    "display_transformations(sample_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KaKZ624jBlt6"
   },
   "source": [
    "Feel free to try your own custom ImageDataGenerators! The results can be very fun to watch. If you check the [docs](https://keras.io/api/preprocessing/image/) there are some other parameters you may want to toy with.\n",
    "\n",
    "Now that you know what data augmentation is doing to the training images let's move onto training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6vO9TP1dJ5My"
   },
   "outputs": [],
   "source": [
    "# Load pretrained model and history\n",
    "\n",
    "augmented_history = pd.read_csv('history-augmented/history-augmented.csv')\n",
    "augmented_model = tf.keras.models.load_model('model-augmented')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7aSV4CyGHRz-"
   },
   "outputs": [],
   "source": [
    "# Run only if you want to train the model yourself (this takes around 20 mins with GPU enabled)\n",
    "\n",
    "# augmented_history = augmented_model.fit(\n",
    "#     train_generator,\n",
    "#     steps_per_epoch=100,\n",
    "#     epochs=80,\n",
    "#     validation_data=validation_generator,\n",
    "#     validation_steps=80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d0hoorf7brwZ"
   },
   "source": [
    "Since you know that class imbalance is no longer an issue there is no need to check for more in-depth metrics. \n",
    "\n",
    "Let's plot the training history right away:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8EYc1oXmHjE2"
   },
   "outputs": [],
   "source": [
    "plot_train_eval(augmented_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nBy1VcxacPEx"
   },
   "source": [
    "Now, the evaluation accuracy follows more closely the training one. This indicates that **the model is no longer overfitting**. Quite a remarkable finding, achieved by just augmenting the data set. Another option to handle overfitting is to include dropout layers in your model as mentioned earlier.\n",
    "\n",
    "Another point worth mentioning, is that this model achieves a slightly lower evaluation accuracy when compared to the model without data augmentation. The reason for this, is that this model needs more epochs to train. To spot this issue, check that for the model without data augmentation, the training accuracy reached almost 100%, whereas the augmented one can still improve. \n"
   ]
  },
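  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One simple way to put a number on this comparison is the gap between training and evaluation accuracy over the last few epochs. The sketch below is only an illustration: the curves are made up, the helper `accuracy_gap` is hypothetical, and the column names `accuracy` and `val_accuracy` are an assumption about how the history is stored.\n",
    "\n",
    "```python\n",
    "def accuracy_gap(history, last_n=5):\n",
    "    # Mean of (training accuracy - evaluation accuracy) over the final epochs;\n",
    "    # a large positive gap hints at overfitting\n",
    "    train = history['accuracy'][-last_n:]\n",
    "    val = history['val_accuracy'][-last_n:]\n",
    "    return sum(t - v for t, v in zip(train, val)) / len(train)\n",
    "\n",
    "\n",
    "# Made-up curves for illustration only\n",
    "no_augmentation = {'accuracy': [0.70, 0.85, 0.95, 0.99],\n",
    "                   'val_accuracy': [0.68, 0.74, 0.76, 0.75]}\n",
    "with_augmentation = {'accuracy': [0.60, 0.68, 0.73, 0.76],\n",
    "                     'val_accuracy': [0.61, 0.67, 0.72, 0.74]}\n",
    "\n",
    "print(f'Gap without augmentation: {accuracy_gap(no_augmentation, last_n=2):.3f}')\n",
    "print(f'Gap with augmentation: {accuracy_gap(with_augmentation, last_n=2):.3f}')\n",
    "```\n",
    "\n",
    "A gap close to zero, together with a training accuracy that is still rising, matches the situation described above: little overfitting, and room to improve with more epochs."
   ]
  },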
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dOA93ENHczla"
   },
   "source": [
    "## Wrapping it up\n",
    "\n",
    "**Congratulations on finishing this ungraded lab!** \n",
    "\n",
    "It is quite amazing to see how data alone can impact Deep Learning models. Hopefully this lab helped you have a better understanding of the importance of data. \n",
    "\n",
    "In particular, you figured out ways to diagnose the effects of class imbalance and looked at specific metrics to spot this problem. Adding more data is a simple way to overcome class imbalance. However, this is not always feasible in a real life scenario.\n",
    "\n",
    "In the final section, you applied multiple geometric transformations to the images in the training dataset, to generate an augmented version. The goal was to use data augmentation to reduce overfitting. Changing the network architecture is an alternative method to reduce overfitting. In practice, it is a good idea to implement both techniques for better results.\n",
    "\n",
    "\n",
    "**Keep it up!**"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "C1W2_Ungraded_Lab_Birds_Cats_Dogs.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
